# Script setup. The global environment is intentionally not wiped here:
# rm(list = ls()) does not give a clean session (attached packages and
# options survive it), so run this script in a fresh R session instead.
options(stringsAsFactors = FALSE)
seed_to_use <- 614
set.seed(seed_to_use)
library(data.table)
library(parallel)
library(BART)
library(pscl)
library(survey)

load("protests_econ_data.rdata")
load("imputed_protests.rdata")

# Formula 1 is passed to model.frame(). It holds the response `treatment`,
# the propensity covariates, and the extra columns (`total_protests`, `year`)
# that the later MSM step reads from the model frame.
form1 <- treatment ~ lag1_treatment + (standardize_bop) + (std_cab_usd) +
  (lag1_standardize_bop) + (lag1_std_cab_usd) + (log(lag1_gdp_pc)) +
  (log(gdp_pc)) + gdp_pct_change + lag1_gdp_pct_change +
  factor(country) + democ + lag1_democ + total_protests + year

# Formula 2 is one-sided and is passed to model.matrix(). It builds the
# covariate design matrix for the BART propensity model.
form2 <- ~ lag1_treatment + (standardize_bop) + (std_cab_usd) +
  gdp_pct_change + lag1_gdp_pct_change +
  (lag1_standardize_bop) + (lag1_std_cab_usd) + (log(lag1_gdp_pc)) +
  (log(gdp_pc)) + factor(country) +
  democ + lag1_democ

## function to estimate msm ----
# Binarizes the protest measure `e3` at `cutpoint` and estimates stabilized
# inverse-probability-of-treatment weights (IPTW) with a logit BART
# propensity model. It then fits three models of `total_protests` on
# current, lagged, and cumulative earlier treatment.
#
# Arguments
#   data_to_use: country-year data (data.table or data.frame) with `country`,
#     `year`, `e3`, and every variable used in `formula1`/`formula2`. It is
#     copied first, so the caller's object is never modified by reference.
#   formula1: formula for model.frame(). Its response must be `treatment`,
#     and its variables must include `factor(country)`, `year` and
#     `total_protests`, which are carried into the MSM data. Rows with NA in
#     any of its variables are dropped. This includes each country's first
#     year, whose lagged treatment is NA.
#   formula2: one-sided formula for the BART propensity design matrix.
#   cutpoint: rows with `e3 > cutpoint` are coded treatment = 1, else 0.
#   trim_probs: lower/upper quantiles at which the weights are winsorized.
#
# Returns a list:
#   no_weights: coefficients of an unweighted quasi-Poisson GLM;
#   quasi_poisson_weights: coefficients of a survey-weighted quasi-Poisson
#     GLM (weights = cumulative IPTW, clusters = country);
#   zero_infl_pois: coefficients of a weighted pscl::zeroinfl() model;
#   weights: the country-year data.table the three models were fit to.
binarize_treatment_estimate_iptw <- function(
  data_to_use = imputed_protests,
  formula1 = form1,
  formula2 = form2,
  cutpoint = mean(imputed_protests$e3, na.rm = TRUE),
  trim_probs = c(0.01, 0.99)
) {
  d <- .add_binary_treatment(data_to_use, cutpoint)
  m_dt <- .estimate_iptw(d, formula1, formula2, trim_probs)
  panel <- .build_msm_panel(m_dt)
  c(.fit_msms(panel), list(weights = panel))
}

# Copies `data` and codes treatment = 1 when e3 > cutpoint. It also adds
# within-country lags of treatment and a within-country time index, which
# custom formulas can use.
.add_binary_treatment <- function(data, cutpoint) {
  d <- data.table::copy(data)
  data.table::setDT(d)
  d[, treatment_mean := (e3 > cutpoint)]
  d[, treatment := as.integer(treatment_mean)]
  # NOTE(review): the lags assume rows are in chronological order within
  # each country. Confirm the input is sorted by country and year.
  d[, lag1_treatment := shift(treatment, n = 1L), by = country]
  d[, lag2_treatment := shift(treatment, n = 2L), by = country]
  d[, time := (year - min(year)) + 1L, by = country]
  d
}

# Fits the logit-BART propensity model on the rows model.frame() keeps. It
# returns that model frame as a data.table with stabilized, winsorized
# weights in `ipw1`, and with `factor(country)` renamed to `country`.
.estimate_iptw <- function(d, formula1, formula2, trim_probs) {
  m_df <- model.frame(formula1, data = d)
  # Build the design matrix from exactly the rows model.frame() kept. This
  # keeps x.train and y.train aligned when a row is NA only in formula1
  # variables (e.g. NA `e3`, hence NA `treatment`).
  kept <- setdiff(seq_len(nrow(d)), as.integer(attr(m_df, "na.action")))
  m_mat <- model.matrix(formula2, data = d[kept])
  stopifnot(nrow(m_mat) == nrow(m_df))

  fit <- BART::lbart(
    x.train = m_mat, y.train = m_df$treatment,
    sparse = TRUE, printevery = 2000L
  )
  # Posterior-mean propensity score, one per row of m_df.
  pscore <- colMeans(plogis(fit$yhat.train))
  # Stabilizing numerator: marginal P(treatment = 1) among coded rows.
  p_treated <- mean(d$treatment, na.rm = TRUE)

  m_dt <- data.table::as.data.table(m_df)
  m_dt[, ipw1 := ifelse(
    treatment == 1, p_treated / pscore, (1 - p_treated) / (1 - pscore)
  )]
  # Winsorize extreme weights at the `trim_probs` quantiles.
  bounds <- quantile(m_dt$ipw1, probs = trim_probs, names = FALSE)
  m_dt[, ipw1 := pmin(pmax(ipw1, bounds[1L]), bounds[2L])]
  data.table::setnames(m_dt, "factor(country)", "country")
  m_dt
}

# Builds the per-country MSM data: the cumulative product of the weights,
# lagged treatments, and a running count of treatments at lag 2 or earlier.
.build_msm_panel <- function(m_dt) {
  panel <- m_dt[, .(
    year = year,
    cumu_ipw1 = cumprod(ipw1),
    treatment = treatment,
    lag_1_treatment = shift(treatment, n = 1, type = "lag", fill = NA),
    lag_2_treatment = shift(treatment, n = 2, type = "lag", fill = NA),
    total_protests = total_protests
  ), by = country]
  # A missing 2-period lag counts as "not treated" before accumulating.
  panel[, lag_2_treatment_na_omit := ifelse(
    is.na(lag_2_treatment), 0, lag_2_treatment
  )]
  panel[,
    marginalize_treatment := cumsum(lag_2_treatment_na_omit), by = country]
  panel
}

# Fits the unweighted, survey-weighted and zero-inflated models to `panel`
# and returns their coefficients.
.fit_msms <- function(panel) {
  msm_formula <-
    total_protests ~ treatment + lag_1_treatment + marginalize_treatment
  unweighted <- glm(msm_formula, data = panel, family = quasipoisson())
  design <- survey::svydesign(
    ids = ~factor(country), data = panel, weights = ~cumu_ipw1
  )
  weighted <- survey::svyglm(
    msm_formula, design = design, family = quasipoisson()
  )
  # zeroinfl() warnings are deliberately silenced (presumably repeated fit
  # warnings across many bootstrap replicates).
  zinf <- suppressWarnings(
    pscl::zeroinfl(msm_formula, data = panel, weights = cumu_ipw1)
  )
  list(
    no_weights = coef(unweighted),
    quasi_poisson_weights = coef(weighted),
    zero_infl_pois = coef(zinf)
  )
}



## cluster bootstrap ----
# Runs one bootstrap replicate of the MSM. It draws countries with
# replacement, seeding the RNG with the replicate index `i` so each
# replicate is reproducible under mclapply. It then stacks each drawn
# country's rows and re-runs binarize_treatment_estimate_iptw() at `cut`.
#
# Each draw is relabelled `<country>_<draw index>`, so a country drawn more
# than once enters as separate clusters. Without this, several steps would
# treat the stacked copies as one long series: the `by = country` lags, the
# cumulative weights and sums, the `factor(country)` term, and the survey
# clusters. The returned `weights$country` therefore holds these relabelled
# ids.
boot_msm <- function(i,
  df = imputed_protests,
  cut = mean(imputed_protests$e3, na.rm = TRUE)
  ) {
  cat("\n \n", i, "\n \n")
  set.seed(i)
  pool <- unique(df$country)
  # Draw by index: sample(x) on a single numeric value samples from 1:x.
  # When `pool` has more than one element, this uses the RNG exactly as
  # sample(pool, replace = TRUE) would.
  countries <- pool[sample.int(length(pool), replace = TRUE)]
  boot_data <- rbindlist(lapply(seq_along(countries), function(draw) {
    cname <- countries[draw]
    chunk <- df[df$country == cname]
    chunk[, country := paste(cname, draw, sep = "_")]
    chunk
  }))
  binarize_treatment_estimate_iptw(data_to_use = boot_data, cutpoint = cut)
}
## run the bootstrap at three cutpoints ----
# Fits `n_boot` bootstrap replicates of the MSM at one cutpoint and returns
# the mclapply result list (one element per replicate, seeds 1..n_boot).
# mclapply does not stop when a replicate fails. The failure comes back as a
# "try-error" object, or NULL if its worker died. Such replicates are counted
# and reported so they can be filtered out before summarising.
run_boot_msm <- function(cutpoint, n_boot = 20000L, n_cores = 48L) {
  vals <- mclapply(
    seq_len(n_boot), boot_msm, mc.cores = n_cores, cut = cutpoint
  )
  failed <- vapply(
    vals,
    function(v) is.null(v) || inherits(v, "try-error"),
    logical(1)
  )
  if (any(failed)) {
    warning(
      sum(failed), " of ", n_boot, " bootstrap replicates failed at cutpoint ",
      format(cutpoint),
      call. = FALSE
    )
  }
  vals
}

# save() stores each result under its variable name, so keep these names.
boot_msm_vals <- run_boot_msm(mean(imputed_protests$e3, na.rm = TRUE))
save(boot_msm_vals, file = "no_trade_msm.rdata")

boot_msm_vals_40 <- run_boot_msm(
  quantile(imputed_protests$e3, 0.4, na.rm = TRUE)
)
save(boot_msm_vals_40, file = "no_trade_msm_40.rdata")

boot_msm_vals_60 <- run_boot_msm(
  quantile(imputed_protests$e3, 0.6, na.rm = TRUE)
)
save(boot_msm_vals_60, file = "no_trade_msm_60.rdata")

